-- Level cap consulted by RandomLevel when it picks the height of a new node.
local Max_Level = 999

-- One skip-list entry: a key, a value and a lazily created table of forward
-- links indexed by level (nextNodes[level] is the successor on that level).
local Node = { __cname = "util.SkipList.Node" }
Node.__index = Node

-- Build a node holding key `k` and value `v`.
function Node.new(k, v)
    local node = setmetatable({}, Node)
    node:ctor(k, v)
    return node
end

-- Initialise the fields of a freshly created node; no links exist yet.
function Node:ctor(k, v)
    self.k, self.v = k, v
    self.nextNodes = nil
end

-- Make `n` the successor of this node on `level` (pass nil to clear the link).
function Node:SetNext(level, n)
    local links = self.nextNodes
    if links == nil then
        links = {}
        self.nextNodes = links
    end
    links[level] = n
end

-- Return the successor on `level`, or nil when there is none.
-- The link table is created on first access, even for a read.
function Node:GetNext(level)
    local links = self.nextNodes
    if links == nil then
        links = {}
        self.nextNodes = links
    end
    return links[level]
end


-- Class table for the skip list; instances use it as their metatable and
-- resolve methods through __index.
local SkipList = { __cname = "util.SkipList" }
SkipList.__index = SkipList

-- Create an empty skip list ordered by `cmpFunc`.
-- `cmpFunc(a, b)` is used by the list as a three-way comparison: a negative
-- result means `a` sorts before `b`, zero means the keys are equal.
-- Raises an assertion error when `cmpFunc` is nil.
function SkipList.new(cmpFunc)
    assert(cmpFunc ~= nil)
    local list = setmetatable({}, SkipList)
    list:ctor(cmpFunc)
    return list
end

-- Initialise an empty list: no entries, no index levels, and a key-less
-- sentinel head node that every level's chain starts from.
function SkipList:ctor(cmpFunc)
    self.count = 0
    self.levelCount = 0 -- number of index levels
    self.head = Node.new()
    self.cmpFunc = cmpFunc
end

-- Remove every entry and return the list to its freshly constructed state.
-- The reset is unconditional (no early exit when count is already 0) so that
-- any leftover head links or level count are always wiped.
function SkipList:Clear()
    self.head.nextNodes = nil
    self.count = 0
    self.levelCount = 0
end

-- Return the number of entries currently stored in the list.
function SkipList:GetCount()
    return self.count
end

-- Draw a random height for a new node.
-- Starts at 1 and grows by one with probability 1/2 per step, never exceeding
-- Max_Level. Returns a level in the range [1, Max_Level].
local function RandomLevel()
    local level = 1
    -- The cap is tested first and with a strict `<`: using `<=` would allow one
    -- more increment at level == Max_Level and return Max_Level + 1.
    while level < Max_Level and math.random() < 0.5 do
        level = level + 1
    end
    return level
end

-- Insert key `k` with value `v`.
-- Returns true when the entry was added, false when an equal key (cmpFunc
-- result 0) is already present; existing entries are never overwritten.
function SkipList:Add(k, v)
    local head = self.head

    -- Single descent from the top index level. On every level remember the
    -- last node whose key is not greater than k: that node is the predecessor
    -- the new node must be linked after. Starting at the top (rather than at
    -- the new node's own height) is what keeps the search logarithmic.
    local preds = {}
    local cur = head
    for level = self.levelCount, 1, -1 do
        cur = self:_FindNext(k, cur, level)
        preds[level] = cur
    end

    -- After the descent `cur` is the level-1 predecessor; if it carries an
    -- equal key the insert is rejected. The head sentinel has no key.
    if cur ~= head and 0 == self.cmpFunc(cur.k, k) then
        return false
    end

    local newLevel = RandomLevel()
    if newLevel > self.levelCount then
        self.levelCount = newLevel
    end

    local newNode = Node.new(k, v)
    for level = 1, newLevel do
        -- Levels above the previous top have no recorded predecessor; the
        -- new node hangs directly off the head there.
        local pred = preds[level] or head
        newNode:SetNext(level, pred:GetNext(level))
        pred:SetNext(level, newNode)
    end

    self.count = self.count + 1
    return true
end

-- Look up the value stored under key `k`.
-- Returns the value, or nil when no entry compares equal to `k`.
function SkipList:Get(k)
    local node = self:_Find(k)
    -- _Find returns the head sentinel when k sorts before every stored key.
    -- The sentinel's key is nil, so it must be excluded before calling the
    -- comparator, which is only expected to handle real keys.
    if nil ~= node and node ~= self.head and 0 == self.cmpFunc(node.k, k) then
        return node.v
    end
    return nil
end

-- Delete the entry whose key compares equal to `k`.
-- Returns `true, value` when an entry was removed, otherwise `false`.
function SkipList:Remove(k)
    if self.count < 1 then return false end

    local cmpFunc = self.cmpFunc
    local head = self.head
    local removed = nil
    local cur = head
    for level = self.levelCount, 1, -1 do
        local nt = cur:GetNext(level)
        while nil ~= nt do
            local cmpResult = cmpFunc(k, nt.k)
            -- k lies between cur and nt: drop down one index level.
            if cmpResult < 0 then break end

            if 0 == cmpResult then
                -- nt is the node to delete: splice it out of this level and
                -- keep cur as the starting point for the level below.
                cur:SetNext(level, nt:GetNext(level))
                removed = nt
                break
            end
            -- k is still beyond nt: keep moving forward on this level.
            cur = nt
            nt = cur:GetNext(level)
        end
    end

    if not removed then return false end

    -- Shrink the level count past every index level that became empty. When
    -- the last entry was removed no level has a node left, so this reaches 0
    -- instead of leaving the old level count behind.
    local top = self.levelCount
    while top > 0 and nil == head:GetNext(top) do
        top = top - 1
    end
    self.levelCount = top

    self.count = self.count - 1
    return true, removed.v
end

-- Report whether an entry whose key compares equal to `k` is stored.
function SkipList:Contains(k)
    local hit = self:_Find(k)
    -- No candidate at all, or only the key-less head sentinel: not present.
    if hit == nil or hit == self.head then
        return false
    end
    return self.cmpFunc(hit.k, k) == 0
end

-- Descend from the top index level to level 1 and return the last node whose
-- key is not greater than `k`. The result is the head sentinel when `k` sorts
-- before every key, and nil when the list is empty.
function SkipList:_Find(k)
    if self.count < 1 then
        return nil
    end

    local node = self.head
    local level = self.levelCount
    while level >= 1 do
        node = self:_FindNext(k, node, level)
        level = level - 1
    end
    return node
end

-- Walk forward on `level` starting at `cur` and return the last node reached
-- before the chain ends or before a node whose key is greater than `k`.
function SkipList:_FindNext(k, cur, level)
    local compare = self.cmpFunc
    while true do
        local candidate = cur:GetNext(level)
        -- End of this level, or k lies between cur and candidate: stop here
        -- so the caller can drop down one index level.
        if candidate == nil or compare(k, candidate.k) < 0 then
            return cur
        end
        cur = candidate
    end
end

-- Render the list for debugging, one line per index level in the form
-- "<level>: k1(v1) -> k2(v2)\n". Returns "" for an empty list.
function SkipList:__tostring()
    if self.count < 1 then return "" end
    local strTb = {}

    for level = 1, self.levelCount do
        strTb[#strTb + 1] = tostring(level)
        strTb[#strTb + 1] = ": "

        -- Start at the head sentinel (index 0); it has no key and is skipped,
        -- so the first real entry sits at index 1 and gets no separator.
        local cur = self.head
        local i = 0
        while cur do
            if nil ~= cur.k then
                if i > 1 then
                    strTb[#strTb + 1] = " -> "
                end
                strTb[#strTb + 1] = tostring(cur.k)
                strTb[#strTb + 1] = "("
                -- Values go through tostring because table.concat only
                -- accepts strings and numbers and would raise on booleans,
                -- tables, functions or userdata. A nil value prints nothing.
                if nil ~= cur.v then
                    strTb[#strTb + 1] = tostring(cur.v)
                end
                strTb[#strTb + 1] = ")"
            end
            cur = cur:GetNext(level)
            i = i + 1
        end

        strTb[#strTb + 1] = "\n"
    end

    return table.concat(strTb)
end

-- Return a stateful iterator function that yields `key, value` for each entry
-- along the level-1 chain, then nil. Usable directly in a generic `for`:
--   for k, v in list:GetIterator() do ... end
function SkipList:GetIterator()
    local pending = self.head:GetNext(1)
    return function()
        local node = pending
        if node == nil then
            return nil
        end
        -- Advance before returning so the next call resumes after `node`.
        pending = node:GetNext(1)
        return node.k, node.v
    end
end

-- Module result: the SkipList class table.
return SkipList